{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a34fe42d-86e4-4e86-a91b-89363cb31394",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-12-25 03:13:49,134] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "from trl import SFTTrainer, DataCollatorForCompletionOnlyLM\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments,AutoConfig\n",
    "from datasets import Dataset\n",
    "import torch\n",
    "import bitsandbytes as bnb\n",
    "from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "226b20c7-6cfb-40d8-adb5-eb5db4da9799",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import requests\n",
    "from datasets import Dataset\n",
    "\n",
    "def _load_split(csv_path):\n",
    "    \"\"\"Read one CSV split and return it as a (DataFrame, Dataset) pair.\"\"\"\n",
    "    frame = pd.read_csv(csv_path)\n",
    "    return frame, Dataset.from_pandas(frame)\n",
    "\n",
    "\n",
    "train_df, train_dataset = _load_split('train.csv')\n",
    "validation_df, validation_dataset = _load_split('val.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2e2ce98a-8365-4e77-ae26-69ea968a6369",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb, os\n",
    "wandb_project = \"mix-finetune\"\n",
    "# Export the project name only when one is configured.\n",
    "if wandb_project:\n",
    "    os.environ[\"WANDB_PROJECT\"] = wandb_project"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65cf322f-4791-4ea2-b9be-81c4325738e6",
   "metadata": {},
   "source": [
    "# Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3eab2e43-8b5c-4677-8476-c854f5ef65bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9dd2cfa5862b45cb830a96e56190b456",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model_id = 'mistralai/Mixtral-8x7B-v0.1'\n",
    "\n",
    "# Prefer bfloat16 where the GPU supports it; otherwise use float16.\n",
    "torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\n",
    "\n",
    "# 4-bit NF4 quantization with double quantization; compute runs in torch_dtype.\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch_dtype,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    ")\n",
    "\n",
    "# Training-oriented config flags: generation cache off, gradient checkpointing on.\n",
    "config = AutoConfig.from_pretrained(model_id)\n",
    "config.use_cache = False\n",
    "config.gradient_checkpointing = True\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    config=config,\n",
    "    quantization_config=bnb_config,\n",
    "    trust_remote_code=False,\n",
    "    torch_dtype=torch_dtype,\n",
    "    device_map=\"auto\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3c2e8cf0-8d48-422d-98ac-4f283685b200",
   "metadata": {},
   "outputs": [],
   "source": [
    "pad_token_id = 0\n",
    "tokenizer = AutoTokenizer.from_pretrained('mistralai/Mixtral-8x7B-v0.1',\n",
    "                                          trust_remote_code=False,\n",
    "                                          use_fast=True)\n",
    "# Reuse the token with id 0 as the padding token.\n",
    "tokenizer.pad_token_id = pad_token_id\n",
    "tokenizer.pad_token = tokenizer.convert_ids_to_tokens(pad_token_id)\n",
    "# SFTTrainer warns about overflow issues in half-precision training when the\n",
    "# tokenizer does not pad on the right, so pad on the right explicitly.\n",
    "tokenizer.padding_side = \"right\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87bd46ed-a8d0-4c85-a906-4739897c28e3",
   "metadata": {},
   "source": [
    "# QLoRA configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "854ab493-1013-4b2d-901c-a5beb6fd8c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_all_linear_names(model, add_lm_head=True):\n",
    "    \"\"\"Collect the leaf names of every 4-bit linear layer in `model`.\n",
    "\n",
    "    Args:\n",
    "        model: Object exposing `named_modules()`, yielding `(name, module)` pairs.\n",
    "        add_lm_head: When true, `\"lm_head\"` is included in the result even if no\n",
    "            `bnb.nn.Linear4bit` module carries that name.\n",
    "\n",
    "    Returns:\n",
    "        Alphabetically sorted list of unique module names (the last component of\n",
    "        each dotted path), so the result is identical from run to run.\n",
    "    \"\"\"\n",
    "    lora_module_names = {\n",
    "        name.rsplit('.', 1)[-1]\n",
    "        for name, module in model.named_modules()\n",
    "        if isinstance(module, bnb.nn.Linear4bit)\n",
    "    }\n",
    "\n",
    "    if add_lm_head:\n",
    "        lora_module_names.add(\"lm_head\")\n",
    "\n",
    "    return sorted(lora_module_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "54101cbf-8a13-4ce0-b11a-82c6e97256a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['k_proj', 'w3', 'v_proj', 'gate', 'lm_head', 'q_proj', 'w2', 'w1', 'o_proj']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Names of the layers that are handed to LoraConfig as adapter targets\n",
    "target_modules = find_all_linear_names(model, add_lm_head=True)\n",
    "target_modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "31e06f66-9b7c-4fd9-a618-1fdb0c51873d",
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################################\n",
    "# QLoRA parameters\n",
    "################################################################################\n",
    "\n",
    "lora_r = 64          # LoRA attention dimension (rank)\n",
    "lora_alpha = 16      # alpha parameter for LoRA scaling\n",
    "lora_dropout = 0.1   # dropout probability for LoRA layers\n",
    "\n",
    "# Prepare the quantized base model for training before adapters are attached.\n",
    "model = prepare_model_for_kbit_training(\n",
    "    model,\n",
    "    use_gradient_checkpointing=True,\n",
    ")\n",
    "\n",
    "peft_config = LoraConfig(\n",
    "    task_type=\"CAUSAL_LM\",\n",
    "    target_modules=target_modules,\n",
    "    r=lora_r,\n",
    "    lora_alpha=lora_alpha,\n",
    "    lora_dropout=lora_dropout,\n",
    "    bias=\"none\",\n",
    "    inference_mode=False,\n",
    ")\n",
    "\n",
    "# Wrap the prepared model according to peft_config.\n",
    "model = get_peft_model(model, peft_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b2d4d07-6dc1-4bbe-b070-7621c818ef3e",
   "metadata": {},
   "source": [
    "# Define training arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "352949fc-ad8a-4b51-9eb1-464645ff8d62",
   "metadata": {},
   "outputs": [],
   "source": [
    "################################################################################\n",
    "# TrainingArguments parameters\n",
    "################################################################################\n",
    "\n",
    "# Number of training epochs\n",
    "num_train_epochs = 3\n",
    "\n",
    "# Mixed precision: bf16 when the GPU supports it, otherwise fp16.\n",
    "# Exactly one of the two flags is enabled.\n",
    "bf16 = torch.cuda.is_bf16_supported()\n",
    "fp16 = not bf16\n",
    "\n",
    "# Batch size per GPU for training (evaluation uses twice this size)\n",
    "per_device_train_batch_size = 2\n",
    "\n",
    "# Number of update steps to accumulate the gradients for\n",
    "gradient_accumulation_steps = 2\n",
    "\n",
    "# Initial learning rate (AdamW optimizer)\n",
    "learning_rate = 1e-4\n",
    "\n",
    "# Optimizer to use\n",
    "optim = \"paged_adamw_8bit\"\n",
    "\n",
    "# Log, evaluate and checkpoint every X update steps\n",
    "logging_steps = 10\n",
    "eval_steps = 10\n",
    "save_steps = 10\n",
    "\n",
    "# Maximum number of checkpoints kept in output_dir\n",
    "save_total_limit = 3\n",
    "\n",
    "# Constant learning rate after a warmup phase. The plain \"constant\" schedule\n",
    "# does not use warmup_steps, so the warmup variant is required for the\n",
    "# warmup_steps value below to have any effect.\n",
    "lr_scheduler_type = \"constant_with_warmup\"\n",
    "warmup_steps = 50\n",
    "\n",
    "gradient_checkpointing = True\n",
    "weight_decay = 0.05\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    do_train=True,\n",
    "    do_eval=True,\n",
    "    output_dir=\"./checkpoints-3\",\n",
    "    dataloader_drop_last=True,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    logging_strategy=\"steps\",\n",
    "    num_train_epochs=num_train_epochs,\n",
    "    eval_steps=eval_steps,\n",
    "    save_steps=save_steps,\n",
    "    logging_steps=logging_steps,\n",
    "    per_device_train_batch_size=per_device_train_batch_size,\n",
    "    per_device_eval_batch_size=per_device_train_batch_size * 2,\n",
    "    optim=optim,\n",
    "    learning_rate=learning_rate,\n",
    "    lr_scheduler_type=lr_scheduler_type,\n",
    "    warmup_steps=warmup_steps,\n",
    "    gradient_accumulation_steps=gradient_accumulation_steps,\n",
    "    gradient_checkpointing=gradient_checkpointing,\n",
    "    weight_decay=weight_decay,\n",
    "    report_to=\"wandb\",\n",
    "    load_best_model_at_end=True,\n",
    "    save_total_limit=save_total_limit,\n",
    "    bf16=bf16,\n",
    "    fp16=fp16,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2457d04-db09-4e64-bdd2-e20fbca47801",
   "metadata": {},
   "source": [
    "# Define trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e6aad83b-7f0a-472f-ba1c-8df46ea984a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "21146d82ff3d462f8a12586a8101a9a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/329 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "65459583516d4633b77bbb8d929ec186",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/37 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:282: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Maximum sequence length handed to the trainer\n",
    "block_size = 1024\n",
    "\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    args=training_args,\n",
    "    # datasets and the text column passed alongside them\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=validation_dataset,\n",
    "    dataset_text_field=\"text\",\n",
    "    max_seq_length=block_size,\n",
    "    # no custom collator or packing setting is supplied\n",
    "    packing=None,\n",
    "    data_collator=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7153b8c3-7986-4b06-999c-3d5a3679ecc0",
   "metadata": {},
   "source": [
    "# Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e17678b2-3ce3-464d-beea-da04a127a537",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run training and show the value returned by trainer.train() as the cell output\n",
    "train_output = trainer.train()\n",
    "train_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "341b903f-ad4b-40e8-b58d-3af1593977f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory passed to trainer.save_model\n",
    "result_dir = 'mix-qlora-result'\n",
    "trainer.save_model(result_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
